import cv2
import torch
import torchvision.transforms as transforms
from model import resnet50
import os
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network from `resnet50` with 3 output classes and move it to `device`.
model = resnet50(num_classes=3).to(device)

# Load the trained weights. The existence check raises explicitly instead of
# using `assert`, because assertions are stripped when Python runs with -O.
weights_path = "./resNet50.pth"
if not os.path.exists(weights_path):
    raise FileNotFoundError("file: '{}' dose not exist.".format(weights_path))
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))

# Inference only: put the model into evaluation mode.
model.eval()

# Scene class labels; the model's predicted index is used to look up the label here.
classes = ["1", "2", "3"]

# Per-class scene-switch counters, index-aligned with `classes`.
scene_counts = [0] * len(classes)

# Image preprocessing applied to each frame before it is passed to the model:
# convert to a PIL image, resize to 224x224, convert to a tensor, then
# normalize each channel with the given mean and standard deviation.
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Open the video file.
cap = cv2.VideoCapture('1.mp4')

# Label predicted for the previous frame; None until the first frame is classified.
previous_scene = None

try:
    # Read frames one by one and classify each of them. If the video could not
    # be opened, isOpened() is False and the loop body never runs.
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame was returned (end of the stream or a read failure).
            break

        # OpenCV delivers frames in BGR channel order, whereas ToPILImage treats
        # a 3-channel array as RGB. Convert first so the channels are not
        # swapped. NOTE(review): assumes the model was trained on RGB images,
        # as the RGB-ordered normalization constants suggest -- confirm.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Preprocess and add a leading batch dimension.
        batch = transform(rgb_frame).unsqueeze(0)

        # Classify the frame without tracking gradients.
        with torch.no_grad():
            logits = torch.squeeze(model(batch.to(device))).cpu()
            probabilities = torch.softmax(logits, dim=0)
            predicted_index = torch.argmax(probabilities).item()
        scene = classes[predicted_index]

        # Count a switch whenever the label differs from the previous frame's.
        # The switch is attributed to the scene that was just left.
        if previous_scene is not None and previous_scene != scene:
            left_index = classes.index(previous_scene)
            scene_counts[left_index] += 1
            print("第{}类 次数{}".format(left_index + 1, scene_counts[left_index]))

        # Remember this frame's scene for the next iteration.
        previous_scene = scene

        # Overlay the current label on the (still BGR) frame and display it.
        cv2.putText(frame, scene, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv2.imshow('Video', frame)
        # Pressing 'q' stops playback early.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close the window even if processing failed.
    cap.release()
    cv2.destroyAllWindows()

# Print the per-class switch totals.
for i, scene in enumerate(classes):
    print(f'{scene}切换次数：{scene_counts[i]}')